/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.datashark.server.master.runner;

import static com.datashark.common.constants.Constants.DEFAULT_WORKER_GROUP;
import static com.datashark.common.enums.TaskEventType.DISPATCH;

import com.datashark.common.constants.Constants;
import com.datashark.common.enums.Flag;
import com.datashark.common.enums.Priority;
import com.datashark.common.enums.TaskEventType;
import com.datashark.common.thread.ThreadUtils;
import com.datashark.dao.entity.Environment;
import com.datashark.dao.entity.ProcessDefinition;
import com.datashark.dao.entity.ProcessTaskRelation;
import com.datashark.dao.entity.TaskDefinition;
import com.datashark.dao.entity.TaskInstance;
import com.datashark.dao.mapper.ProcessTaskRelationMapper;
import com.datashark.dao.repository.TaskInstanceDao;
import com.datashark.dao.utils.TaskInstanceUtils;
import com.datashark.extract.base.client.SingletonJdkDynamicRpcClientProxyFactory;
import com.datashark.extract.master.transportor.StreamingTaskTriggerRequest;
import com.datashark.extract.worker.ITaskInstanceExecutionEventAckListener;
import com.datashark.extract.worker.transportor.TaskInstanceExecutionFinishEventAck;
import com.datashark.extract.worker.transportor.TaskInstanceExecutionInfoEventAck;
import com.datashark.extract.worker.transportor.TaskInstanceExecutionRunningEventAck;
import com.datashark.plugin.task.api.TaskChannel;
import com.datashark.plugin.task.api.TaskExecutionContext;
import com.datashark.plugin.task.api.TaskPluginManager;
import com.datashark.plugin.task.api.enums.TaskExecutionStatus;
import com.datashark.plugin.task.api.model.Property;
import com.datashark.plugin.task.api.parameters.AbstractParameters;
import com.datashark.plugin.task.api.parameters.ParametersNode;
import com.datashark.plugin.task.api.parameters.resource.ResourceParametersHelper;
import com.datashark.plugin.task.api.utils.LogUtils;
import com.datashark.plugin.task.api.utils.ParameterUtils;
import com.datashark.server.master.builder.TaskExecutionContextBuilder;
import com.datashark.server.master.cache.StreamTaskInstanceExecCacheManager;
import com.datashark.server.master.config.MasterConfig;
import com.datashark.server.master.event.StateEventHandleError;
import com.datashark.server.master.event.StateEventHandleException;
import com.datashark.server.master.metrics.TaskMetrics;
import com.datashark.server.master.processor.queue.TaskEvent;
import com.datashark.server.master.runner.dispatcher.WorkerTaskDispatcher;
import com.datashark.server.master.runner.execute.DefaultTaskExecuteRunnableFactory;
import com.datashark.service.bean.SpringApplicationContext;
import com.datashark.service.process.ProcessService;

import org.apache.commons.lang3.StringUtils;

import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;

/**
 * Executes one stream task instance triggered by a {@link StreamingTaskTriggerRequest}.
 *
 * <p>The instance is not attached to a workflow instance (its process instance id is set to {@code 0}).
 * {@link #run()} creates and persists the {@link TaskInstance}, registers this runnable in the
 * {@link StreamTaskInstanceExecCacheManager} and dispatches the instance to a worker. Events reported by the worker
 * afterwards are queued through {@link #addTaskEvent(TaskEvent)} and drained by {@link #handleEvents()}, which
 * persists the reported changes and acknowledges each handled event back to the worker.
 */
@Slf4j
public class StreamTaskExecuteRunnable implements Runnable {

    protected MasterConfig masterConfig;

    protected ProcessService processService;

    protected TaskInstanceDao taskInstanceDao;

    protected DefaultTaskExecuteRunnableFactory defaultTaskExecuteRunnableFactory;

    protected WorkerTaskDispatcher workerTaskDispatcher;

    protected ProcessTaskRelationMapper processTaskRelationMapper;

    protected TaskPluginManager taskPluginManager;

    private final StreamTaskInstanceExecCacheManager streamTaskInstanceExecCacheManager;

    protected TaskDefinition taskDefinition;

    protected TaskInstance taskInstance;

    protected ProcessDefinition processDefinition;

    protected StreamingTaskTriggerRequest taskExecuteStartMessage;

    protected TaskExecutionContextFactory taskExecutionContextFactory;

    /**
     * Pending task events reported by the worker, consumed head-first by {@link #handleEvents()}.
     */
    private final ConcurrentLinkedQueue<TaskEvent> taskEvents = new ConcurrentLinkedQueue<>();

    /**
     * Lifecycle flag. It is written by {@link #run()} and read by {@link #isStart()} / {@link #handleEvents()}.
     * It is volatile so the STARTED transition, and the {@link #taskInstance} assignment that precedes it, are
     * visible to a reader on another thread.
     * NOTE(review): assumes run() and handleEvents() execute on different threads; confirm with the caller.
     */
    private volatile TaskRunnableStatus taskRunnableStatus = TaskRunnableStatus.CREATED;

    /**
     * Creates the runnable and resolves all collaborators from the Spring application context.
     *
     * @param taskDefinition          definition of the stream task to execute
     * @param taskExecuteStartMessage trigger request carrying tenant, executor, dry-run flag and start params
     */
    public StreamTaskExecuteRunnable(TaskDefinition taskDefinition,
                                     StreamingTaskTriggerRequest taskExecuteStartMessage) {
        this.processService = SpringApplicationContext.getBean(ProcessService.class);
        this.masterConfig = SpringApplicationContext.getBean(MasterConfig.class);
        this.workerTaskDispatcher = SpringApplicationContext.getBean(WorkerTaskDispatcher.class);
        this.taskPluginManager = SpringApplicationContext.getBean(TaskPluginManager.class);
        this.processTaskRelationMapper = SpringApplicationContext.getBean(ProcessTaskRelationMapper.class);
        this.taskInstanceDao = SpringApplicationContext.getBean(TaskInstanceDao.class);
        this.streamTaskInstanceExecCacheManager =
                SpringApplicationContext.getBean(StreamTaskInstanceExecCacheManager.class);
        this.taskDefinition = taskDefinition;
        this.taskExecuteStartMessage = taskExecuteStartMessage;
        this.taskExecutionContextFactory = SpringApplicationContext.getBean(TaskExecutionContextFactory.class);
        this.defaultTaskExecuteRunnableFactory =
                SpringApplicationContext.getBean(DefaultTaskExecuteRunnableFactory.class);
    }

    /**
     * @return the task instance created by {@link #run()}, or {@code null} if run() has not created it yet
     */
    public TaskInstance getTaskInstance() {
        return taskInstance;
    }

    /**
     * Creates and persists the stream task instance, caches this runnable, and dispatches the instance to a worker.
     *
     * <p>If dispatching fails, the instance is marked {@link TaskExecutionStatus#FAILURE}, persisted, and evicted
     * from the cache again. Otherwise a runnable that never starts would stay cached forever.
     *
     * @throws IllegalStateException if the task definition has no process task relation, or the related process
     *                               definition cannot be found
     */
    @Override
    public void run() {
        // submit task
        processService.updateTaskDefinitionResources(taskDefinition);
        // TODO (translated from the original note): "at the very beginning, the task instance is running".
        // newTaskInstance actually creates it as SUBMITTED_SUCCESS; confirm the intended initial state.
        taskInstance = newTaskInstance(taskDefinition);

        processDefinition = findRelatedProcessDefinition();
        taskInstance.setProcessDefine(processDefinition);
        taskInstance.setTaskDefine(taskDefinition);
        taskInstance.setTenantCode(taskExecuteStartMessage.getTenantCode());
        taskInstance.setProcessInstanceName(processDefinition.getName());
        taskInstanceDao.upsertTaskInstance(taskInstance);

        // add cache
        streamTaskInstanceExecCacheManager.cache(taskInstance.getId(), this);

        try {
            DefaultTaskExecuteRunnable taskExecuteRunnable =
                    defaultTaskExecuteRunnableFactory.createTaskExecuteRunnable(taskInstance);

            workerTaskDispatcher.dispatchTask(taskExecuteRunnable);
        } catch (Exception e) {
            log.error("Master dispatch task to worker error, taskInstanceName: {}", taskInstance.getName(), e);
            // This runnable never reaches STARTED, so its events would never be handled: do not keep it cached.
            // Evict first, so a failing upsert below cannot leave it behind.
            streamTaskInstanceExecCacheManager.removeByTaskInstanceId(taskInstance.getId());
            taskInstance.setState(TaskExecutionStatus.FAILURE);
            taskInstanceDao.upsertTaskInstance(taskInstance);
            return;
        }
        // set started flag
        taskRunnableStatus = TaskRunnableStatus.STARTED;
        log.info("Master success dispatch task to worker, taskInstanceId: {}, worker: {}", taskInstance.getId(),
                taskInstance.getHost());
    }

    /**
     * Looks up the process definition the stream task belongs to, using the first relation returned for its code.
     *
     * @return the related process definition, never {@code null}
     * @throws IllegalStateException if no relation or no matching process definition exists
     */
    private ProcessDefinition findRelatedProcessDefinition() {
        List<ProcessTaskRelation> processTaskRelationList =
                processTaskRelationMapper.queryByTaskCode(taskDefinition.getCode());
        if (processTaskRelationList == null || processTaskRelationList.isEmpty()) {
            throw new IllegalStateException(
                    "No process task relation found for stream task, taskCode: " + taskDefinition.getCode());
        }
        ProcessTaskRelation relation = processTaskRelationList.get(0);
        long processDefinitionCode = relation.getProcessDefinitionCode();
        int processDefinitionVersion = relation.getProcessDefinitionVersion();
        ProcessDefinition definition =
                processService.findProcessDefinition(processDefinitionCode, processDefinitionVersion);
        if (definition == null) {
            throw new IllegalStateException("Process definition not found for stream task, taskCode: "
                    + taskDefinition.getCode() + ", processDefinitionCode: " + processDefinitionCode
                    + ", processDefinitionVersion: " + processDefinitionVersion);
        }
        return definition;
    }

    /**
     * @return {@code true} once {@link #run()} has successfully dispatched the task instance
     */
    public boolean isStart() {
        return TaskRunnableStatus.STARTED == taskRunnableStatus;
    }

    /**
     * Queues an event reported for this runnable's task instance.
     *
     * @param taskEvent the event to queue
     * @return {@code true} if the event was queued; {@code false} if the task instance has not been created yet or
     *         the event belongs to a different task instance
     */
    public boolean addTaskEvent(TaskEvent taskEvent) {
        TaskInstance currentTaskInstance = taskInstance;
        if (currentTaskInstance == null || currentTaskInstance.getId() != taskEvent.getTaskInstanceId()) {
            log.info("state event would be abounded, taskInstanceId:{}, eventType:{}, state:{}",
                    taskEvent.getTaskInstanceId(), taskEvent.getEvent(), taskEvent.getState());
            return false;
        }
        taskEvents.add(taskEvent);
        return true;
    }

    /**
     * @return number of events currently queued and not yet handled
     */
    public int eventSize() {
        return this.taskEvents.size();
    }

    /**
     * Drains the event queue, handling the head event until the queue is empty.
     *
     * <p>This does nothing until the runnable has started. Error handling per event:
     * <ul>
     *     <li>{@link StateEventHandleError}: the event is removed (dropped), then the thread sleeps.</li>
     *     <li>{@link StateEventHandleException} or any other exception: the event stays at the head and is retried
     *     after a sleep.</li>
     * </ul>
     */
    public void handleEvents() {
        if (!isStart()) {
            log.info(
                    "The stream task instance is not started, will not handle its state event, current state event size: {}",
                    taskEvents.size());
            return;
        }
        TaskEvent taskEvent = null;
        while (!this.taskEvents.isEmpty()) {
            try {
                taskEvent = this.taskEvents.peek();
                LogUtils.setTaskInstanceIdMDC(taskEvent.getTaskInstanceId());

                log.info("Begin to handle state event, {}", taskEvent);
                if (this.handleTaskEvent(taskEvent)) {
                    this.taskEvents.remove(taskEvent);
                }
            } catch (StateEventHandleError stateEventHandleError) {
                log.error("State event handle error, will remove this event: {}", taskEvent, stateEventHandleError);
                this.taskEvents.remove(taskEvent);
                ThreadUtils.sleep(Constants.SLEEP_TIME_MILLIS);
            } catch (StateEventHandleException stateEventHandleException) {
                log.error("State event handle error, will retry this event: {}",
                        taskEvent,
                        stateEventHandleException);
                ThreadUtils.sleep(Constants.SLEEP_TIME_MILLIS);
            } catch (Exception e) {
                // Catch everything here: if handling fails, the event is still in the taskEvents queue and will be
                // retried on the next iteration.
                log.error("State event handle error, get a unknown exception, will retry this event: {}",
                        taskEvent,
                        e);
                ThreadUtils.sleep(Constants.SLEEP_TIME_MILLIS);
            } finally {
                LogUtils.removeWorkflowAndTaskInstanceIdMDC();
            }
        }
    }

    /**
     * Builds a new, not yet persisted task instance from the given definition and the trigger request.
     *
     * @param taskDefinition the stream task definition
     * @return the new task instance: state {@code SUBMITTED_SUCCESS}, process instance id {@code 0}, submit time
     *         set to now
     */
    public TaskInstance newTaskInstance(TaskDefinition taskDefinition) {
        TaskInstance taskInstance = new TaskInstance();
        taskInstance.setTaskCode(taskDefinition.getCode());
        taskInstance.setTaskDefinitionVersion(taskDefinition.getVersion());
        taskInstance.setName(taskDefinition.getName());
        // task instance state
        taskInstance.setState(TaskExecutionStatus.SUBMITTED_SUCCESS);
        // stream tasks are not part of a workflow instance, so use process instance id 0
        taskInstance.setProcessInstanceId(0);
        taskInstance.setProjectCode(taskDefinition.getProjectCode());
        // task instance type; Locale.ROOT keeps the result independent of the JVM default locale (e.g. Turkish 'i')
        taskInstance.setTaskType(taskDefinition.getTaskType().toUpperCase(Locale.ROOT));
        // task instance whether alert
        taskInstance.setAlertFlag(Flag.NO);

        // task instance start time
        taskInstance.setStartTime(null);

        // task instance flag
        taskInstance.setFlag(Flag.YES);

        // task instance current retry times
        taskInstance.setRetryTimes(0);
        taskInstance.setMaxRetryTimes(taskDefinition.getFailRetryTimes());
        taskInstance.setRetryInterval(taskDefinition.getFailRetryInterval());

        // set task param
        taskInstance.setTaskParams(taskDefinition.getTaskParams());

        // set task group and priority
        taskInstance.setTaskGroupId(taskDefinition.getTaskGroupId());
        taskInstance.setTaskGroupPriority(taskDefinition.getTaskGroupPriority());

        // set task cpu quota and max memory
        taskInstance.setCpuQuota(taskDefinition.getCpuQuota());
        taskInstance.setMemoryMax(taskDefinition.getMemoryMax());

        // task instance priority, MEDIUM unless the definition specifies one
        taskInstance.setTaskInstancePriority(Priority.MEDIUM);
        if (taskDefinition.getTaskPriority() != null) {
            taskInstance.setTaskInstancePriority(taskDefinition.getTaskPriority());
        }

        // delay execution time
        taskInstance.setDelayTime(taskDefinition.getDelayTime());

        // task dry run flag
        taskInstance.setDryRun(taskExecuteStartMessage.getDryRun());

        taskInstance.setWorkerGroup(StringUtils.isBlank(taskDefinition.getWorkerGroup()) ? DEFAULT_WORKER_GROUP
                : taskDefinition.getWorkerGroup());
        // environment code 0 is stored as -1, and -1 means "no environment to load"
        taskInstance.setEnvironmentCode(
                taskDefinition.getEnvironmentCode() == 0 ? -1 : taskDefinition.getEnvironmentCode());

        if (!taskInstance.getEnvironmentCode().equals(-1L)) {
            Environment environment = processService.findEnvironmentByCode(taskInstance.getEnvironmentCode());
            if (Objects.nonNull(environment) && StringUtils.isNotEmpty(environment.getConfig())) {
                taskInstance.setEnvironmentConfig(environment.getConfig());
            }
        }

        if (taskInstance.getSubmitTime() == null) {
            taskInstance.setSubmitTime(new Date());
        }
        if (taskInstance.getFirstSubmitTime() == null) {
            taskInstance.setFirstSubmitTime(taskInstance.getSubmitTime());
        }

        taskInstance.setTaskExecuteType(taskDefinition.getTaskExecuteType());
        taskInstance.setExecutorId(taskExecuteStartMessage.getExecutorId());
        taskInstance.setExecutorName(taskExecuteStartMessage.getExecutorName());

        return taskInstance;
    }

    /**
     * Builds the execution context sent to the worker for the given task instance.
     *
     * @param taskInstance the task instance to build the context for
     * @return the execution context, or {@code null} when no tenant can be resolved
     */
    protected TaskExecutionContext getTaskExecutionContext(TaskInstance taskInstance) {
        int userId = taskDefinition == null ? 0 : taskDefinition.getUserId();
        String tenantCode = processService.getTenantForProcess(taskExecuteStartMessage.getTenantCode(), userId);

        // verify tenant is null
        if (StringUtils.isBlank(tenantCode)) {
            log.error("tenant not exists,task instance id : {}", taskInstance.getId());
            return null;
        }

        TaskChannel taskChannel = taskPluginManager.getTaskChannel(taskInstance.getTaskType());
        ResourceParametersHelper resources = taskChannel.getResources(taskInstance.getTaskParams());

        AbstractParameters baseParam = taskPluginManager.getParameters(
                ParametersNode.builder()
                        .taskType(taskInstance.getTaskType())
                        .taskParams(taskInstance.getTaskParams())
                        .build());
        Map<String, Property> propertyMap = paramParsingPreparation(taskInstance, baseParam);
        TaskExecutionContext taskExecutionContext = TaskExecutionContextBuilder.get()
                .buildWorkflowInstanceHost(masterConfig.getMasterAddress())
                .buildTaskInstanceRelatedInfo(taskInstance)
                .buildTaskDefinitionRelatedInfo(taskDefinition)
                .buildResourceParametersInfo(resources)
                .buildBusinessParamsMap(new HashMap<>())
                .buildParamInfo(propertyMap)
                .create();

        taskExecutionContext.setTenantCode(tenantCode);
        taskExecutionContext.setProjectCode(processDefinition.getProjectCode());
        taskExecutionContext.setProcessDefineCode(processDefinition.getCode());
        taskExecutionContext.setProcessDefineVersion(processDefinition.getVersion());
        // stream tasks have no workflow instance: process instance id defaults to 0
        taskExecutionContext.setProcessInstanceId(0);
        taskExecutionContextFactory.setDataQualityTaskExecutionContext(taskExecutionContext, taskInstance, tenantCode);
        taskExecutionContextFactory.setK8sTaskRelatedInfo(taskExecutionContext, taskInstance);
        return taskExecutionContext;
    }

    /**
     * Applies a single worker event to the task instance, records metrics and acknowledges the event to the worker.
     * The runnable is evicted from the cache once the instance reaches a finished state.
     *
     * @param taskEvent the event to handle
     * @return {@code true} when the event was handled and can be removed from the queue
     * @throws StateEventHandleError if persisting, metrics, or acknowledging fails; the caller drops the event
     */
    protected boolean handleTaskEvent(TaskEvent taskEvent) throws StateEventHandleException, StateEventHandleError {
        TaskEventType eventType = taskEvent.getEvent();
        if (eventType == DISPATCH) {
            // DISPATCH events are intentionally ignored: nothing is persisted or acknowledged for them
            return true;
        }

        try {
            switch (eventType) {
                case RUNNING:
                    handleRunningEvent(taskEvent);
                    break;
                case UPDATE_PID:
                    handleUpdatePidEvent(taskEvent);
                    break;
                case RESULT:
                    handleResultEvent(taskEvent);
                    break;
                default:
                    log.warn("Unhandled task event type: {}", taskEvent.getEvent());
                    break;
            }

            recordTaskStateMetrics(taskEvent);

            sendAckToWorker(taskEvent);

            if (taskInstance.getState().isFinished()) {
                streamTaskInstanceExecCacheManager.removeByTaskInstanceId(taskInstance.getId());
                log.info("The stream task instance is finished, taskInstanceId:{}, state:{}",
                        taskInstance.getId(),
                        taskEvent.getState());
            }

            return true;
        } catch (StateEventHandleError stateEventHandleError) {
            // already rolled back in memory by updateTaskInstanceWithRollback
            throw stateEventHandleError;
        } catch (Exception ex) {
            throw new StateEventHandleError("Handle stream task event error, update taskInstance to db failed", ex);
        }
    }

    /**
     * Persists a RUNNING event: state, start time, host, log path and execute path.
     */
    private void handleRunningEvent(TaskEvent taskEvent) throws StateEventHandleError {
        updateTaskInstanceWithRollback(taskInstance -> {
            taskInstance.setState(taskEvent.getState());
            taskInstance.setStartTime(taskEvent.getStartTime());
            taskInstance.setHost(taskEvent.getWorkerAddress());
            taskInstance.setLogPath(taskEvent.getLogPath());
            taskInstance.setExecutePath(taskEvent.getExecutePath());
        });
    }

    /**
     * Persists an UPDATE_PID event: process id and application link.
     */
    private void handleUpdatePidEvent(TaskEvent taskEvent) throws StateEventHandleError {
        updateTaskInstanceWithRollback(taskInstance -> {
            taskInstance.setPid(taskEvent.getProcessId());
            taskInstance.setAppLink(taskEvent.getAppIds());
        });
    }

    /**
     * Persists a RESULT event: final state, timings, location info and var pool, and updates out params.
     */
    private void handleResultEvent(TaskEvent taskEvent) throws StateEventHandleError {
        updateTaskInstanceWithRollback(taskInstance -> {
            taskInstance.setStartTime(taskEvent.getStartTime());
            taskInstance.setHost(taskEvent.getWorkerAddress());
            taskInstance.setLogPath(taskEvent.getLogPath());
            taskInstance.setExecutePath(taskEvent.getExecutePath());
            taskInstance.setPid(taskEvent.getProcessId());
            taskInstance.setAppLink(taskEvent.getAppIds());
            taskInstance.setState(taskEvent.getState());
            taskInstance.setEndTime(taskEvent.getEndTime());
            taskInstance.setVarPool(taskEvent.getVarPool());
            processService.changeOutParam(taskInstance);
        });
    }

    /**
     * Applies {@code updateAction} to the in-memory task instance and writes it to the database. If anything
     * fails, the in-memory instance is restored from a snapshot taken beforehand.
     *
     * @param updateAction mutation to apply to {@link #taskInstance}
     * @throws StateEventHandleError if the update throws or the database update reports failure
     */
    private void updateTaskInstanceWithRollback(Consumer<TaskInstance> updateAction) throws StateEventHandleError {
        TaskInstance oldTaskInstance = new TaskInstance();
        TaskInstanceUtils.copyTaskInstance(taskInstance, oldTaskInstance);

        try {
            updateAction.accept(taskInstance);
            if (!taskInstanceDao.updateById(taskInstance)) {
                throw new StateEventHandleError("Update taskInstance to db failed");
            }
        } catch (Exception ex) {
            TaskInstanceUtils.copyTaskInstance(oldTaskInstance, taskInstance);
            if (ex instanceof StateEventHandleError) {
                throw (StateEventHandleError) ex;
            }
            throw new StateEventHandleError("Handle stream task event error, update taskInstance to db failed", ex);
        }
    }

    /**
     * Increments task-state metrics for events that carry a state. Events without a state are only logged.
     */
    private void recordTaskStateMetrics(TaskEvent taskEvent) {
        if (taskEvent == null || taskEvent.getState() == null) {
            // The event is broken or carries no state: there is nothing to measure, and dereferencing would NPE.
            log.warn("The task event is broken..., taskEvent: {}", taskEvent);
            return;
        }
        if (taskEvent.getState().isFinished()) {
            TaskMetrics.incTaskInstanceByState("finish");
        }
        switch (taskEvent.getState()) {
            case KILL:
                TaskMetrics.incTaskInstanceByState("stop");
                break;
            case SUCCESS:
                TaskMetrics.incTaskInstanceByState("success");
                break;
            case FAILURE:
                TaskMetrics.incTaskInstanceByState("fail");
                break;
            default:
                break;
        }
    }

    /**
     * Merges parameters for the task: start params from the trigger request, then var-pool params, then local
     * params. On key collisions, later sources override earlier ones.
     *
     * @param taskInstance task instance whose var pool is copied into {@code parameters}
     * @param parameters   parsed task parameters; its var pool is overwritten as a side effect
     * @return the merged parameter map, or {@code null} when all three sources are empty
     */
    public Map<String, Property> paramParsingPreparation(@NonNull TaskInstance taskInstance,
                                                         @NonNull AbstractParameters parameters) {
        // assign value to definedParams here
        Map<String, String> globalParamsMap = taskExecuteStartMessage.getStartParams();
        Map<String, Property> globalParams = ParameterUtils.getUserDefParamsMap(globalParamsMap);

        // combining local and global parameters
        Map<String, Property> localParams = parameters.getInputLocalParametersMap();

        // stream pass params
        parameters.setVarPool(taskInstance.getVarPool());
        Map<String, Property> varParams = parameters.getVarPoolMap();

        if (globalParams.isEmpty() && localParams.isEmpty() && varParams.isEmpty()) {
            return null;
        }

        // Guarded so that nothing is written to globalParams when there is nothing to add
        if (!varParams.isEmpty()) {
            globalParams.putAll(varParams);
        }
        if (!localParams.isEmpty()) {
            globalParams.putAll(localParams);
        }

        return globalParams;
    }

    /**
     * Acknowledges a successfully handled event to the reporting worker. Without the ack, the worker retries the
     * event.
     */
    private void sendAckToWorker(TaskEvent taskEvent) {
        ITaskInstanceExecutionEventAckListener instanceExecutionEventAckListener =
                SingletonJdkDynamicRpcClientProxyFactory
                        .getProxyClient(taskEvent.getWorkerAddress(), ITaskInstanceExecutionEventAckListener.class);
        TaskEventType eventType = taskEvent.getEvent();
        switch (eventType) {
            case RUNNING:
                instanceExecutionEventAckListener.handleTaskInstanceExecutionRunningEventAck(
                        TaskInstanceExecutionRunningEventAck.success(taskEvent.getTaskInstanceId()));
                break;
            case RESULT:
                instanceExecutionEventAckListener.handleTaskInstanceExecutionFinishEventAck(
                        TaskInstanceExecutionFinishEventAck.success(taskEvent.getTaskInstanceId()));
                break;
            case UPDATE_PID:
                instanceExecutionEventAckListener.handleTaskInstanceExecutionInfoEventAck(
                        TaskInstanceExecutionInfoEventAck.success(taskEvent.getTaskInstanceId()));
                break;
            default:
                log.warn("SendAckToWorker error, get an unknown event: {}", taskEvent);
                break;
        }
    }

    /**
     * Lifecycle of the runnable: CREATED until {@link #run()} dispatches successfully, then STARTED.
     */
    private enum TaskRunnableStatus {
        CREATED,
        STARTED
    }
}
